import numpy as np
import pandas as pd
import os
import sys
import pickle
from joblib import dump, load
import yaml
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse
sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
from gpd_test import *

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
import matplotlib.font_manager as fm
import matplotlib
import sklearn.metrics as metrics
import seaborn as sns
import pdb
#import trajectoryPlotters as tp
sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
from trajectoryPlotters import *



# Default directory holding the model-output pickles, and default plot output directory.
defaultaxpayer_id = '/REDACTED/data/modeled_refactor_temp/'
defaultout = '/REDACTED/'


# Apply the project-wide matplotlib style, then register the bundled TTF font
# under the family name 'latex' (inserted first so it wins lookups) and make
# it the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(
    fname='/REDACTED/fairness/code/utilities/cmunrm.ttf',
    name='latex',
)
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name
plt.rcParams['mathtext.default'] = 'regular'

# Stanford palette as (color name, hex code) pairs, so each name sits next to
# its value. The two parallel lists below are derived from it.
_STANFORD_COLORS = [
    ('background blue', '#f0f4f5'),
    ('gray blue', '#d0d8da'),
    ('dark gray-blue', '#aabeC6'),
    ('accent blue', '#009abb'),
    ('link blue', '#007c92'),
    ('dark blue', '#09425a'),
    ('light green', '#c7d1c6'),
    ('bright green', '#80982a'),
    ('dark green', '#556222'),
    ('orange', '#b96d12'),
    ('purple', '#53284f'),
    ('maroon', '#5e3032'),
    ('cardinal red', '#8c1515'),
    ('line gray', '#dddddd'),
    ('light gray', '#cccccc'),
    ('gray', '#999999'),
    ('med gray', '#666666'),
    ('text gray', '#333333'),
    ('black', '#000000'),
]
stanford_palette = [hex_code for _, hex_code in _STANFORD_COLORS]
color_codes = [color_name for color_name, _ in _STANFORD_COLORS]
cdict = {color_name: mcolors.to_rgba(hex_code) for color_name, hex_code in _STANFORD_COLORS}

# Register the names globally so e.g. color='cardinal red' works in any plot call.
mcolors.get_named_colors_mapping().update(cdict)

###############################################
#### GET PLOT DICTIONARIES
###############################################


def _read_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE(review): pickle can execute arbitrary code while loading; only use
    this on trusted, locally generated files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Status-quo benchmark point and audit-share metadata.
status_quo_point = _read_pickle('/REDACTED/fairness/code/rf/data/status_quo.pickle')
status_quo_eitc = _read_pickle('/REDACTED/data/metadata/audit_shares_eitc_ret.pickle')
status_quo_within_eitc_tpi = _read_pickle('/REDACTED/data/metadata/audit_shares_within_eitc_ret.pickle')
status_quo_tpi = _read_pickle('/REDACTED/data/metadata/audit_shares_ret.pickle')

#### Classifier, threshold 100 (these two files carry a different name suffix
#### than the other thresholds).
unres_eitc_cls_100_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_100_plus_dep_database_output_outcome_chg_in_tax_owed_pv_new_bifsg.pickle')
res_eitc_cls_100_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_100_plus_dep_database_output_outcome_chg_in_tax_owed_pv_new_bifsg.pickle')

#### Other classifier thresholds.
## The "0" threshold intentionally loads the cls_1 files: the pipeline selects
## scores >= threshold, so a threshold of 0 does not work.
res_eitc_cls_0_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_1_plus_dep_database_output_new_bifsg.pickle')
unres_eitc_cls_0_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_1_plus_dep_database_output_new_bifsg.pickle')

res_eitc_cls_50_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_50_plus_dep_database_output_new_bifsg.pickle')
unres_eitc_cls_50_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_50_plus_dep_database_output_new_bifsg.pickle')

res_eitc_cls_500_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_500_plus_dep_database_output_new_bifsg.pickle')
unres_eitc_cls_500_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_500_plus_dep_database_output_new_bifsg.pickle')

res_eitc_cls_1000_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_1000_plus_dep_database_output_new_bifsg.pickle')
unres_eitc_cls_1000_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_1000_plus_dep_database_output_new_bifsg.pickle')

res_eitc_cls_2000_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_res_cls_2000_plus_dep_database_output_new_bifsg.pickle')
unres_eitc_cls_2000_dep_database = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_cls_2000_plus_dep_database_output_new_bifsg.pickle')

# unres_* databases in ascending threshold order, paired position-by-position
# with their names; data_dict maps each name to its getAllQuantities summary.
file_list = [unres_eitc_cls_0_dep_database,
             unres_eitc_cls_50_dep_database,
             unres_eitc_cls_100_dep_database,
             unres_eitc_cls_500_dep_database,
             unres_eitc_cls_1000_dep_database,
             unres_eitc_cls_2000_dep_database]

string_list = ['unres_eitc_cls_0_dep_database',
               'unres_eitc_cls_50_dep_database',
               'unres_eitc_cls_100_dep_database',
               'unres_eitc_cls_500_dep_database',
               'unres_eitc_cls_1000_dep_database',
               'unres_eitc_cls_2000_dep_database']

data_dict = {
    db_name: getAllQuantities(database, bootstrap = True)
    for db_name, database in zip(string_list, file_list)
}

def makeDotPlot(modelnames,
                quantities,
                status_quo_point,
                colors,
                plotname,
                plot_sq=True,
                activity_code=None,
                x_varb='bootstrap_revenue',
                x_varb_sq='rev',
                x_lab='Detected Underreporting ($ Millions)',
                y_varb='bootstrap_fair',
                y_varb_sq='chen_fair',
                se_varb='bootstrap_fair',
                y_lab='Black/Non-Black Disparity (percentage points)',
                demillion_x=True,
                demillion_y=False,
                scale_x=False,
                plot_ci=True,
                dotplot_symbol=None,
                budget='auto',
                out=defaultout,
                marker_size=10,
                eitc=True,
                offsets=((0, 0),),
                annotateModelName=True,
                annotationcoords=((0, 0),),
                sq_coords_eitc=(500, -0.001),
                sq_coords_pop=(1000, -0.0005)):
    """Plot one dot per model at the status-quo audit budget and save it as PNG.

    This is the dot-plot version of the trajectory plotters, meant for comparing
    more than three models at a single budget.

    Parameters
    ----------
    modelnames : sequence of str
        Legend label (and optional annotation text) for each model.
    quantities : sequence of mappings
        One per model. ``q[x_varb + '_mean']``, ``q[x_varb + '_std']``,
        ``q[y_varb + '_mean']`` and ``q[se_varb + '_std']`` must each be
        indexable by the budget index computed below.
    status_quo_point : mapping
        Status-quo benchmarks. The budget is read from
        ``f'{activity_code}_budget'`` when ``activity_code`` is given, otherwise
        from ``'eitc_budget'`` / ``'overall_budget'`` depending on ``eitc``.
        When ``plot_sq`` is true, ``'eitc_'``/``'overall_'`` + ``x_varb_sq`` /
        ``y_varb_sq`` are also read.
    colors : sequence
        One matplotlib color per model.
    plotname : str
        File stem; the figure is saved to ``out + plotname + '.png'``.
    plot_sq : bool
        Draw the status-quo crosshair, marker and annotation.
    demillion_x, demillion_y : bool
        Divide the status-quo x / y value by 1,000,000 before plotting.
    scale_x : bool
        Label x ticks as value / 0.01 (only when ``demillion_x`` is false).
    plot_ci : bool
        Draw +/- 1.96 * std error bars in both directions.
    dotplot_symbol : sequence of str or None
        Per-model marker; None plots plain 'o' dots.
    budget : str
        Must be ``'auto'`` unless ``activity_code`` is given.
    offsets : sequence
        Accepted for backward compatibility; currently unused.
    annotationcoords : sequence of (x, y)
        Per-model data coordinates for the model-name annotation.
    sq_coords_eitc, sq_coords_pop : (dx, dy)
        Offsets of the status-quo annotation text from the status-quo point.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (left open; the caller is responsible for closing it).

    Raises
    ------
    Exception
        If ``budget`` is not ``'auto'`` and no ``activity_code`` is given.
    ValueError
        If the status-quo budget maps to a negative index (budget < 1e-4).
    """
    # Imported here because the file never imports matplotlib.ticker by name.
    import matplotlib.ticker as ticker

    # ---- Pick the status-quo budget that selects the plotted grid point. ----
    if activity_code is not None:
        budget_key = str(activity_code) + '_budget'
    elif budget != 'auto':
        raise Exception('Benchmark budgets now auto-fill in code. Set budget argument to \'auto\'.')
    elif eitc:
        budget_key = 'eitc_budget'
    else:
        budget_key = 'overall_budget'

    # Position of the budget in the per-budget arrays: floor(budget * 10000) - 1
    # (presumably one entry per 0.0001 of budget -- confirm against getAllQuantities).
    idx = math.floor(status_quo_point[budget_key] * 10000) - 1
    if idx < 0:
        # A negative index would silently wrap around to the last grid point.
        raise ValueError('Status-quo budget %r=%r is below the first grid point.'
                         % (budget_key, status_quo_point[budget_key]))

    xs = [q[x_varb + '_mean'][idx] for q in quantities]
    x_errs = [q[x_varb + '_std'][idx] for q in quantities]
    ys = [q[y_varb + '_mean'][idx] for q in quantities]
    y_errs = [q[se_varb + '_std'][idx] for q in quantities]

    # ---- Figure scaffolding (created only after the inputs validated). ----
    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(10, 8))
    fig.suptitle('', fontsize=24)
    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_ylabel(y_lab, fontsize=20)
    ax.set_title('', fontsize=20)
    ax.tick_params(labelsize=16)
    ax.axhline(y=0, color='black')

    # ---- Status-quo crosshair and marker. ----
    if plot_sq:
        prefix = 'eitc_' if eitc else 'overall_'
        sq_x = status_quo_point[prefix + x_varb_sq]
        sq_y = status_quo_point[prefix + y_varb_sq]
        if demillion_x:
            sq_x = sq_x / 1000000
        if demillion_y:
            sq_y = sq_y / 1000000
        ax.axhline(y=sq_y, color='red', linestyle='dotted')
        ax.axvline(x=sq_x, color='red', linestyle='dotted')
        ax.plot(sq_x, sq_y, 'kx', markersize=10)

    # ---- Tick labels: show y (and optionally x) as percentages. ----
    scale_y = 0.01
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_y)))
    if scale_x and not demillion_x:
        x_scale = 0.01
        ax.xaxis.set_major_formatter(
            ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / x_scale)))

    # ---- Status-quo annotation text. ----
    if plot_sq:
        if eitc:
            # The extra (+220, -0.0021) nudge is a fixed layout offset kept
            # from the original on top of sq_coords_eitc.
            ax.annotate('Status quo, ' + str(round(status_quo_point['eitc_budget'] * 100, 2)) + '% audit rate',
                        xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_eitc[0] + 220, sq_y + sq_coords_eitc[1] - .0021),
                        size=16)
        else:
            ax.annotate('Status quo, ' + str(round(status_quo_point['overall_budget'] * 100, 2)) + '% audit rate',
                        xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_pop[0], sq_y + sq_coords_pop[1]),
                        size=16)

    # ---- One dot (plus optional error bars / label) per model. ----
    for i, name in enumerate(modelnames):
        x = np.array(xs[i])
        y = np.array(ys[i])
        if dotplot_symbol is None:
            ax.plot(x, y, 'o', label=name, color=colors[i], markersize=marker_size)
        else:
            ax.plot(x, y, marker=dotplot_symbol[i], label=name, color=colors[i], markersize=marker_size)
        if plot_ci:
            # 1.96 * std: normal-approximation 95% interval.
            ax.errorbar(x, y,
                        xerr=np.array(x_errs[i]) * 1.96,
                        yerr=np.array(y_errs[i]) * 1.96,
                        ecolor=colors[i], capsize=(marker_size + 2))
        if annotateModelName:
            ax.annotate(name, (annotationcoords[i][0], annotationcoords[i][1]), size=16)

    fig.savefig(out + plotname + '.png')
    return fig


# Dot plot comparing the unres_* classifier-threshold databases (0, 50, 100,
# 500, 1000, 2000) at the status-quo EITC budget, then release the figure.
_cls_threshold_quantities = [
    data_dict['unres_eitc_cls_0_dep_database'],
    data_dict['unres_eitc_cls_50_dep_database'],
    data_dict['unres_eitc_cls_100_dep_database'],
    data_dict['unres_eitc_cls_500_dep_database'],
    data_dict['unres_eitc_cls_1000_dep_database'],
    data_dict['unres_eitc_cls_2000_dep_database'],
]

_cls_threshold_fig = makeDotPlot(
    modelnames=['0', '50', '100', '500', '1000', '2000'],
    quantities=_cls_threshold_quantities,
    status_quo_point=status_quo_point,
    colors=['magenta', 'red', 'blue', 'orange', 'green', 'black'],
    plotname='dotplot_prob_combined_eitc_cls_dep_database',
    budget='auto',
    plot_ci=False,
    dotplot_symbol=['o', 'o', 'o', 'o', 'o', 'o'],
    x_varb='bootstrap_revenue',
    y_varb='bootstrap_chen_fair',
    se_varb='bootstrap_chen_fair',
    y_lab='Disparity (probabilistic, percentage points)',
    y_varb_sq='chen_fair',
    plot_sq=True,
    out=defaultout,
    marker_size=8,
    eitc=True,
    annotateModelName=True,
    sq_coords_eitc=[-200, 0.001],
    sq_coords_pop=[-200, 0.001],
    annotationcoords=[[2355, 0.01], [2790, 0.009], [2730, 0.0115],
                      [2730, 0.0135], [2800, 0.0143], [2600, 0.0175]],
)

# makeDotPlot's figure is the current one, so closing it explicitly is
# equivalent to a bare plt.close().
plt.close(_cls_threshold_fig)